{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ch 5.7 单层RBF神经网络\n",
    "根据式(5.18)和(5.19)，试构造一个能解决异或问题的单层RBF神经网络。\n",
    "\n",
    "Note：\n",
    "我也不知道怎么选择center。。。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def RBF(x, center_i, beta_i):\n",
    "    dist = np.sum(pow((x - center_i),2))\n",
    "    rho = np.exp(-beta_i*dist)\n",
    "    return rho\n",
    "\n",
    "# 输出为1\n",
    "class RBF_network:\n",
    "    def __init__(self):\n",
    "        self.hidden_num = 0\n",
    "        self.y = 0\n",
    "        \n",
    "    def createNN(self, input_num, hidden_num, learning_rate, center):\n",
    "        self.input_num = input_num\n",
    "        self.hidden_num = hidden_num\n",
    "        self.center = center\n",
    "        self.w = np.random.random(self.hidden_num)\n",
    "        self.rho = np.zeros(self.hidden_num)\n",
    "        self.beta = np.random.random(self.hidden_num)\n",
    "        self.lr = learning_rate\n",
    "        \n",
    "    def Predict(self, x):\n",
    "        self.y = 0\n",
    "        for i in range(self.hidden_num):\n",
    "            self.rho[i] = RBF(x, self.center[i], self.beta[i])\n",
    "            self.y += self.w[i] * self.rho[i]\n",
    "        return self.y\n",
    "    \n",
    "    def BackPropagate(self, x, y):\n",
    "        self.Predict(x)\n",
    "        grad = np.zeros(self.hidden_num)\n",
    "        for i in range(self.hidden_num):\n",
    "            # dE_k/dy_cap = (y_cap-y)  \n",
    "            # dE_k/dw = (y_cap-y)*rho[i] \n",
    "            # dE_k/d_beta = -(y_cap-y)*rho[i]w_i*||x-c_i||\n",
    "            grad[i] = (self.y - y) * self.rho[i]\n",
    "            self.w[i] -= self.lr * grad[i]\n",
    "            self.beta[i] += self.lr * grad[i] * self.w[i] * np.sum(pow((x - center[i]),2))\n",
    "                          \n",
    "    def trainNN(self, x, y):\n",
    "        error_list = []\n",
    "        for i in range(len(x)):\n",
    "            self.BackPropagate(x[i], y[i])\n",
    "            error = (self.y - y[i])**2\n",
    "            error_list.append(error/2)\n",
    "        print(error_list)\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_x = np.random.randint(0,2,(100,2))\n",
    "train_y = np.logical_xor(train_x[:,0],train_x[:,1])\n",
    "test_x = np.random.randint(0,2,(100,2))\n",
    "test_y = np.logical_xor(test_x[:,0],test_x[:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "center = np.array([[0,0],[0,1],[1,0],[1,1]])\n",
    "rbf = RBF_network()\n",
    "rbf.createNN(input_num = 2, hidden_num=4 , learning_rate=0.1, center=center)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.026450226471867982, 0.93391983970613945, 0.67544996981490701, 0.49421467657995199, 0.0095084276976672745, 0.018820964260078336, 0.013920393999916783, 0.0046275756050139175, 0.0094325318134220442, 0.21598277371049648, 0.0094097009112470067, 0.014885837478294155, 0.01100301518698724, 0.18509657593104051, 0.1368430542114428, 0.021149043320076746, 0.015719168706313603, 0.02234453883878329, 0.016532486213518806, 0.41357287812727056, 0.022110838487044038, 0.11632434375088635, 0.025569698757844272, 0.019029895317142921, 0.10360641780706445, 0.30256306649027848, 0.038738414905934507, 0.24270469368979977, 0.052489856196674241, 0.1698683920305816, 0.072029318536415352, 0.038813438255340886, 0.11627518963725403, 0.053018977392016074, 0.03960691251503172, 0.029579057417204475, 0.20155569523209454, 0.035009711515775921, 0.083583998172591156, 0.021182346901497222, 0.015813372354595654, 0.011802243087649267, 0.052001123967976813, 0.0065465240218205711, 0.10893845497740777, 0.0098571764816098358, 0.086522991164620658, 0.059001106755437197, 0.043855025864776319, 0.032578767016338391, 0.26312710492963581, 0.19659896413806802, 0.057087194533256093, 0.073374876192042485, 0.028614865113947664, 0.04873545164474076, 0.036283859635530179, 0.015265031576476195, 0.011379576294771225, 0.093702470512944835, 0.19112279230448914, 0.046803297468740984, 0.020923845225340663, 0.1676820502980724, 0.058898949257328082, 0.044419443570576894, 0.10674900649610061, 0.026659111657373464, 0.057555328863012546, 0.042966198538984542, 0.07330241202953508, 0.11159504241557354, 0.033343122287102331, 0.02524673446354753, 0.08458425257171448, 0.08627649532524935, 0.075970558615111428, 0.077409826571923121, 0.058515308495888069, 0.018848946247424156, 0.085590842580048329, 0.065332915524985671, 0.057347086222570016, 0.042800535997451643, 0.039165438963256388, 0.029256422156363544, 0.021857579829068843, 0.1070202842786508, 0.080593592539281819, 0.031781560112788661, 0.037530075229567784, 0.028023052698578594, 0.084640384749162911, 0.025521410776222785, 0.070581471127235576, 0.025552441790697208, 0.030356875391837266, 0.06784166191559203, 0.022815307312747109, 0.057063135472262792]\n"
     ]
    }
   ],
   "source": [
    "rbf.trainNN(train_x, train_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy:  1.0\n"
     ]
    }
   ],
   "source": [
    "print('accuracy: ',np.sum([rbf.Predict(xx)>0.5 for xx in test_x] == test_y)/len(test_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
